{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "FNdZ-kD0l78P"
   },
   "source": [
    "#  在单个 GPU 上针对自定义代码微调代码 LLM\n",
    "\n",
    "_作者: [Maria Khalusova](https://github.com/MKhalusova)_\n",
    "\n",
    "公开发布的代码 LLM，如 Codex、StarCoder 和 Code Llama，在生成遵循通用编程原则和语法的代码方面表现出色，但它们可能不符合组织的内部惯例，或者不了解某些特定的库。\n",
    "\n",
    "在这个 notebook 中，我们将展示如何微调代码 LLM 来更好的理解你们公司或组织的代码风格和习惯。由于代码 LLM 非常大，按照传统的微调方式可能会消耗大量资源。但不用担心！我们会教你一些技巧，让你只用单个 GPU 就能完成微调工作。\n",
    "\n",
    "\n",
    "## 数据集\n",
    "\n",
    "对于这个例子，我们选择了 GitHub 上 Hugging Face 的前 10 个公共仓库。我们已经排除了非代码文件，如图片、音频文件、演示文稿等。对于 Jupyter notebook，我们只保留了包含代码的单元格。生成的代码被存储为一个数据集，你可以在 Hugging Face Hub 上找到，位于 [`smangrul/hf-stack-v1`](https://huggingface.co/datasets/smangrul/hf-stack-v1)。它包含仓库 id、文件路径和文件内容。\n",
    "\n",
    "\n",
    "## 模型\n",
    "\n",
    "我们将微调 [`bigcode/starcoderbase-1b`](https://huggingface.co/bigcode/starcoderbase-1b) 模型，这是一个在 80 多种编程语言上训练的 10 亿参数模型。这是一个需要权限的模型，所以如果你计划使用这个确切模型运行这个 notebook，你需要在其模型页面上获得访问权限。登录你的 Hugging Face 帐户以执行此操作：\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "bPlCJYDK6vrF"
   },
   "outputs": [],
   "source": [
    "from huggingface_hub import notebook_login\n",
    "\n",
    "notebook_login()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "WMVe_c8q43Qo"
   },
   "source": [
    "为了开始，首先让我们安装所有必要的库。正如你所看到的，除了 `transformers` 和 `datasets`，我们还将使用 `peft`、`bitsandbytes` 和 `flash-attn` 来优化训练过程。\n",
    "\n",
    "通过采用参数高效的训练技术，我们可以在一张 A100 高内存 GPU 上运行这个 Notebook。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "Fp7i8WMCjKJG"
   },
   "outputs": [],
   "source": [
    "!pip install -q transformers datasets peft bitsandbytes flash-attn"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "16EdABzt3_Ig"
   },
   "source": [
    "现在让我们定义一些变量。请随意调整这些变量。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "hru3G-CLmqis"
   },
   "outputs": [],
   "source": [
    "MODEL=\"bigcode/starcoderbase-1b\" # Model checkpoint on the Hugging Face Hub\n",
    "DATASET=\"smangrul/hf-stack-v1\"   # Dataset on the Hugging Face Hub\n",
    "DATA_COLUMN=\"content\"            # Column name containing the code content\n",
    "\n",
    "SEQ_LENGTH=2048                  # Sequence length\n",
    "\n",
    "# Training arguments\n",
    "MAX_STEPS=2000                   # max_steps\n",
    "BATCH_SIZE=16                    # batch_size\n",
    "GR_ACC_STEPS=1                   # gradient_accumulation_steps\n",
    "LR=5e-4                          # learning_rate\n",
    "LR_SCHEDULER_TYPE=\"cosine\"       # lr_scheduler_type\n",
    "WEIGHT_DECAY=0.01                # weight_decay\n",
    "NUM_WARMUP_STEPS=30              # num_warmup_steps\n",
    "EVAL_FREQ=100                    # eval_freq\n",
    "SAVE_FREQ=100                    # save_freq\n",
    "LOG_FREQ=25                      # log_freq\n",
    "OUTPUT_DIR=\"peft-starcoder-lora-a100\" # output_dir\n",
    "BF16=True                        # bf16\n",
    "FP16=False                       # no_fp16\n",
    "\n",
    "# FIM trasformations arguments\n",
    "FIM_RATE=0.5                     # fim_rate\n",
    "FIM_SPM_RATE=0.5                 # fim_spm_rate\n",
    "\n",
    "# LORA\n",
    "LORA_R=8                         # lora_r\n",
    "LORA_ALPHA=32                    # lora_alpha\n",
    "LORA_DROPOUT=0.0                 # lora_dropout\n",
    "LORA_TARGET_MODULES=\"c_proj,c_attn,q_attn,c_fc,c_proj\"    # lora_target_modules\n",
    "\n",
    "# bitsandbytes config\n",
    "USE_NESTED_QUANT=True            # use_nested_quant\n",
    "BNB_4BIT_COMPUTE_DTYPE=\"bfloat16\"# bnb_4bit_compute_dtype\n",
    "\n",
    "SEED=0"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "FyZSXTbJrcnC"
   },
   "outputs": [],
   "source": [
    "from transformers import (\n",
    "    AutoModelForCausalLM,\n",
    "    AutoTokenizer,\n",
    "    Trainer,\n",
    "    TrainingArguments,\n",
    "    logging,\n",
    "    set_seed,\n",
    "    BitsAndBytesConfig,\n",
    ")\n",
    "\n",
    "set_seed(SEED)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "pO7F5L5AtKo1"
   },
   "source": [
    "## 准备数据"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "1LmrIZqP0oUE"
   },
   "source": [
    "首先加载数据。由于数据集可能相当大，请确保启用流模式。流模式允许我们在遍历数据集时逐步加载数据，而不是一次性下载数据集的整个内容。\n",
    "\n",
    "我们将前 4000 个示例作为验证集，其余的全部作为训练数据。\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "4oJZvZb-1J88"
   },
   "outputs": [],
   "source": [
    "from datasets import load_dataset\n",
    "import torch\n",
    "from tqdm import tqdm\n",
    "\n",
    "\n",
    "dataset = load_dataset(\n",
    "    DATASET,\n",
    "    data_dir=\"data\",\n",
    "    split=\"train\",\n",
    "    streaming=True,\n",
    ")\n",
    "\n",
    "valid_data = dataset.take(4000)\n",
    "train_data = dataset.skip(4000)\n",
    "train_data = train_data.shuffle(buffer_size=5000, seed=SEED)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "sLQ8t0LM2GR6"
   },
   "source": [
    "在这一步，数据集仍然包含任意长度的原始数据。为了训练，我们需要固定长度的输入。让我们创建一个可迭代的数据集，它可以从文本文件流中返回固定长度的 token 块。\n",
    "\n",
    "首先，让我们估计数据集中每个 token 的平均字符数，这将帮助我们稍后估计文本缓冲区中的 token 数量。默认情况下，我们只从数据集中取 400 个示例（`nb_examples`）。只使用整个数据集的一个子集可以减少计算成本，同时仍然提供了对整体字符到 token 比的合理估计。\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "KCiAvydztNsu",
    "outputId": "cabf7fd0-a922-4371-cbc6-60ee99ef7469"
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 400/400 [00:10<00:00, 39.87it/s] "
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "The character to token ratio of the dataset is: 2.43\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    }
   ],
   "source": [
    "tokenizer = AutoTokenizer.from_pretrained(MODEL, trust_remote_code=True)\n",
    "\n",
    "def chars_token_ratio(dataset, tokenizer, data_column, nb_examples=400):\n",
    "    \"\"\"\n",
    "    Estimate the average number of characters per token in the dataset.\n",
    "    \"\"\"\n",
    "\n",
    "    total_characters, total_tokens = 0, 0\n",
    "    for _, example in tqdm(zip(range(nb_examples), iter(dataset)), total=nb_examples):\n",
    "        total_characters += len(example[data_column])\n",
    "        total_tokens += len(tokenizer(example[data_column]).tokens())\n",
    "\n",
    "    return total_characters / total_tokens\n",
    "\n",
    "\n",
    "chars_per_token = chars_token_ratio(train_data, tokenizer, DATA_COLUMN)\n",
    "print(f\"The character to token ratio of the dataset is: {chars_per_token:.2f}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "6F13VGobB3Ma"
   },
   "source": [
    "字符到 token 的比也可以用作文本标记质量的一个指标。例如，字符到 token 的比为 1.0 意味着每个字符都由一个 token 表示，这并没有太多意义。表明标记化做得不好。在标准的英文文本中，一个 token 通常相当于大约四个字符，这意味着字符到 token 的比率大约是 4.0。我们可以预见在代码数据集中的比率会更低，但一般来说，2.0 到 3.5 之间的数字可以认为是足够好的。"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "rcwYFRPpwxea"
   },
   "source": [
    "**可选的 FIM 变换**\n",
    "自回归语言模型通常是从左到右生成序列的。通过应用 FIM 变换，模型也可以学习填充文本。详细信息可以看[\"Efficient Training of Language Models to Fill in the Middle\" 这篇论文](https://arxiv.org/pdf/2207.14255.pdf)了解这种技术。\n",
    "\n",
    "我们将在下面定义 FIM 变换，并在创建可迭代数据集时使用它们。然而，如果你想省略变换步骤，请将 `fim_rate` 设置为 0。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "zmejYvEKw1E-"
   },
   "outputs": [],
   "source": [
    "import functools\n",
    "import numpy as np\n",
    "\n",
    "\n",
    "# Helper function to get token ids of the special tokens for prefix, suffix and middle for FIM transformations.\n",
    "@functools.lru_cache(maxsize=None)\n",
    "def get_fim_token_ids(tokenizer):\n",
    "    try:\n",
    "        FIM_PREFIX, FIM_MIDDLE, FIM_SUFFIX, FIM_PAD = tokenizer.special_tokens_map[\"additional_special_tokens\"][1:5]\n",
    "        suffix_tok_id, prefix_tok_id, middle_tok_id, pad_tok_id = (\n",
    "            tokenizer.vocab[tok] for tok in [FIM_SUFFIX, FIM_PREFIX, FIM_MIDDLE, FIM_PAD]\n",
    "        )\n",
    "    except KeyError:\n",
    "        suffix_tok_id, prefix_tok_id, middle_tok_id, pad_tok_id = None, None, None, None\n",
    "    return suffix_tok_id, prefix_tok_id, middle_tok_id, pad_tok_id\n",
    "\n",
    "\n",
    "## Adapted from https://github.com/bigcode-project/Megatron-LM/blob/6c4bf908df8fd86b4977f54bf5b8bd4b521003d1/megatron/data/gpt_dataset.py\n",
    "def permute(\n",
    "    sample,\n",
    "    np_rng,\n",
    "    suffix_tok_id,\n",
    "    prefix_tok_id,\n",
    "    middle_tok_id,\n",
    "    pad_tok_id,\n",
    "    fim_rate=0.5,\n",
    "    fim_spm_rate=0.5,\n",
    "    truncate_or_pad=False,\n",
    "):\n",
    "    \"\"\"\n",
    "    Take in a sample (list of tokens) and perform a FIM transformation on it with a probability of fim_rate, using two FIM modes:\n",
    "    PSM and SPM (with a probability of fim_spm_rate).\n",
    "    \"\"\"\n",
    "\n",
    "    # The if condition will trigger with the probability of fim_rate\n",
    "    # This means FIM transformations will apply to samples with a probability of fim_rate\n",
    "    if np_rng.binomial(1, fim_rate):\n",
    "\n",
    "        # Split the sample into prefix, middle, and suffix, based on randomly generated indices stored in the boundaries list.\n",
    "        boundaries = list(np_rng.randint(low=0, high=len(sample) + 1, size=2))\n",
    "        boundaries.sort()\n",
    "\n",
    "        prefix = np.array(sample[: boundaries[0]], dtype=np.int64)\n",
    "        middle = np.array(sample[boundaries[0] : boundaries[1]], dtype=np.int64)\n",
    "        suffix = np.array(sample[boundaries[1] :], dtype=np.int64)\n",
    "\n",
    "        if truncate_or_pad:\n",
    "            # calculate the new total length of the sample, taking into account tokens indicating prefix, middle, and suffix\n",
    "            new_length = suffix.shape[0] + prefix.shape[0] + middle.shape[0] + 3\n",
    "            diff = new_length - len(sample)\n",
    "\n",
    "            # trancate or pad if there's a difference in length between the new length and the original\n",
    "            if diff > 0:\n",
    "                if suffix.shape[0] <= diff:\n",
    "                    return sample, np_rng\n",
    "                suffix = suffix[: suffix.shape[0] - diff]\n",
    "            elif diff < 0:\n",
    "                suffix = np.concatenate([suffix, np.full((-1 * diff), pad_tok_id)])\n",
    "\n",
    "        # With the probability of fim_spm_rateapply SPM variant of FIM transformations\n",
    "        # SPM: suffix, prefix, middle\n",
    "        if np_rng.binomial(1, fim_spm_rate):\n",
    "            new_sample = np.concatenate(\n",
    "                [\n",
    "                    [prefix_tok_id, suffix_tok_id],\n",
    "                    suffix,\n",
    "                    [middle_tok_id],\n",
    "                    prefix,\n",
    "                    middle,\n",
    "                ]\n",
    "            )\n",
    "        # Otherwise, apply the PSM variant of FIM transformations\n",
    "        # PSM: prefix, suffix, middle\n",
    "        else:\n",
    "\n",
    "            new_sample = np.concatenate(\n",
    "                [\n",
    "                    [prefix_tok_id],\n",
    "                    prefix,\n",
    "                    [suffix_tok_id],\n",
    "                    suffix,\n",
    "                    [middle_tok_id],\n",
    "                    middle,\n",
    "                ]\n",
    "            )\n",
    "    else:\n",
    "        # don't apply FIM transformations\n",
    "        new_sample = sample\n",
    "\n",
    "    return list(new_sample), np_rng\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "AwW5FviD9xBH"
   },
   "source": [
    "让我们定义 `ConstantLengthDataset`，这是一个可迭代的数据集，它将返回固定长度的 token 块。为此，我们将从原始数据集中读取文本缓冲区，直到达到大小限制，然后应用分词器将原始文本转换为 token 后的输入。可选项，我们可以在一些序列上执行 FIM 变换（受影响的序列比例由 `fim_rate` 控制）。\n",
    "\n",
    "定义好后，我们可以从训练和验证数据中创建 `ConstantLengthDataset` 的实例。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "AgDW-692wzOl"
   },
   "outputs": [],
   "source": [
    "from torch.utils.data import IterableDataset\n",
    "from torch.utils.data.dataloader import DataLoader\n",
    "import random\n",
    "\n",
    "# Create an Iterable dataset that returns constant-length chunks of tokens from a stream of text files.\n",
    "\n",
    "class ConstantLengthDataset(IterableDataset):\n",
    "    \"\"\"\n",
    "    Iterable dataset that returns constant length chunks of tokens from stream of text files.\n",
    "        Args:\n",
    "            tokenizer (Tokenizer): The processor used for proccessing the data.\n",
    "            dataset (dataset.Dataset): Dataset with text files.\n",
    "            infinite (bool): If True the iterator is reset after dataset reaches end else stops.\n",
    "            seq_length (int): Length of token sequences to return.\n",
    "            num_of_sequences (int): Number of token sequences to keep in buffer.\n",
    "            chars_per_token (int): Number of characters per token used to estimate number of tokens in text buffer.\n",
    "            fim_rate (float): Rate (0.0 to 1.0) that sample will be permuted with FIM.\n",
    "            fim_spm_rate (float): Rate (0.0 to 1.0) of FIM permuations that will use SPM.\n",
    "            seed (int): Seed for random number generator.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(\n",
    "        self,\n",
    "        tokenizer,\n",
    "        dataset,\n",
    "        infinite=False,\n",
    "        seq_length=1024,\n",
    "        num_of_sequences=1024,\n",
    "        chars_per_token=3.6,\n",
    "        content_field=\"content\",\n",
    "        fim_rate=0.5,\n",
    "        fim_spm_rate=0.5,\n",
    "        seed=0,\n",
    "    ):\n",
    "        self.tokenizer = tokenizer\n",
    "        self.concat_token_id = tokenizer.eos_token_id\n",
    "        self.dataset = dataset\n",
    "        self.seq_length = seq_length\n",
    "        self.infinite = infinite\n",
    "        self.current_size = 0\n",
    "        self.max_buffer_size = seq_length * chars_per_token * num_of_sequences\n",
    "        self.content_field = content_field\n",
    "        self.fim_rate = fim_rate\n",
    "        self.fim_spm_rate = fim_spm_rate\n",
    "        self.seed = seed\n",
    "\n",
    "        (\n",
    "            self.suffix_tok_id,\n",
    "            self.prefix_tok_id,\n",
    "            self.middle_tok_id,\n",
    "            self.pad_tok_id,\n",
    "        ) = get_fim_token_ids(self.tokenizer)\n",
    "        if not self.suffix_tok_id and self.fim_rate > 0:\n",
    "            print(\"FIM is not supported by tokenizer, disabling FIM\")\n",
    "            self.fim_rate = 0\n",
    "\n",
    "    def __iter__(self):\n",
    "        iterator = iter(self.dataset)\n",
    "        more_examples = True\n",
    "        np_rng = np.random.RandomState(seed=self.seed)\n",
    "        while more_examples:\n",
    "            buffer, buffer_len = [], 0\n",
    "            while True:\n",
    "                if buffer_len >= self.max_buffer_size:\n",
    "                    break\n",
    "                try:\n",
    "                    buffer.append(next(iterator)[self.content_field])\n",
    "                    buffer_len += len(buffer[-1])\n",
    "                except StopIteration:\n",
    "                    if self.infinite:\n",
    "                        iterator = iter(self.dataset)\n",
    "                    else:\n",
    "                        more_examples = False\n",
    "                        break\n",
    "            tokenized_inputs = self.tokenizer(buffer, truncation=False)[\"input_ids\"]\n",
    "            all_token_ids = []\n",
    "\n",
    "            for tokenized_input in tokenized_inputs:\n",
    "                # optionally do FIM permutations\n",
    "                if self.fim_rate > 0:\n",
    "                    tokenized_input, np_rng = permute(\n",
    "                        tokenized_input,\n",
    "                        np_rng,\n",
    "                        self.suffix_tok_id,\n",
    "                        self.prefix_tok_id,\n",
    "                        self.middle_tok_id,\n",
    "                        self.pad_tok_id,\n",
    "                        fim_rate=self.fim_rate,\n",
    "                        fim_spm_rate=self.fim_spm_rate,\n",
    "                        truncate_or_pad=False,\n",
    "                    )\n",
    "\n",
    "                all_token_ids.extend(tokenized_input + [self.concat_token_id])\n",
    "            examples = []\n",
    "            for i in range(0, len(all_token_ids), self.seq_length):\n",
    "                input_ids = all_token_ids[i : i + self.seq_length]\n",
    "                if len(input_ids) == self.seq_length:\n",
    "                    examples.append(input_ids)\n",
    "            random.shuffle(examples)\n",
    "            for example in examples:\n",
    "                self.current_size += 1\n",
    "                yield {\n",
    "                    \"input_ids\": torch.LongTensor(example),\n",
    "                    \"labels\": torch.LongTensor(example),\n",
    "                }\n",
    "\n",
    "\n",
    "train_dataset = ConstantLengthDataset(\n",
    "        tokenizer,\n",
    "        train_data,\n",
    "        infinite=True,\n",
    "        seq_length=SEQ_LENGTH,\n",
    "        chars_per_token=chars_per_token,\n",
    "        content_field=DATA_COLUMN,\n",
    "        fim_rate=FIM_RATE,\n",
    "        fim_spm_rate=FIM_SPM_RATE,\n",
    "        seed=SEED,\n",
    ")\n",
    "eval_dataset = ConstantLengthDataset(\n",
    "        tokenizer,\n",
    "        valid_data,\n",
    "        infinite=False,\n",
    "        seq_length=SEQ_LENGTH,\n",
    "        chars_per_token=chars_per_token,\n",
    "        content_field=DATA_COLUMN,\n",
    "        fim_rate=FIM_RATE,\n",
    "        fim_spm_rate=FIM_SPM_RATE,\n",
    "        seed=SEED,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "rxev1sk6tRW9"
   },
   "source": [
    "## 准备模型"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "UCtWV-U42Eq_"
   },
   "source": [
    "现在数据已经准备好了，是时候加载模型了！我们将加载量化的模型。\n",
    "\n",
    "因为量化使用更少的位来表示数据，所以会减少内存使用。我们将使用 `bitsandbytes` 库来量化模型，因为它与 `transformers` 有很好的集成。我们需要做的只是定义一个 `bitsandbytes` 配置，然后在加载模型时使用它。\n",
    "\n",
    "4 比特位量化有不同的变体，但通常我们推荐使用 NF4 量化以获得更好的性能（`bnb_4bit_quant_type=\"nf4\"`）。\n",
    "\n",
    "`bnb_4bit_use_double_quant` 选项在第一次量化后添加第二次量化，以节省每个参数额外的 0.4 位。\n",
    "\n",
    "要了解更多关于量化的信息，请查看 [\"利用 bitsandbytes、4 比特位量化和 QLoRA 让 LLMs 更易于访问\" 的博客](https://huggingface.co/blog/4bit-transformers-bitsandbytes)。\n",
    "\n",
    "定义好后，将配置传递给 `from_pretrained` 方法以加载量化的模型。\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "XuwoX6U2DUvK"
   },
   "outputs": [],
   "source": [
    "from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training\n",
    "from peft.tuners.lora import LoraLayer\n",
    "\n",
    "load_in_8bit = False\n",
    "\n",
    "# 4-bit quantization\n",
    "compute_dtype = getattr(torch, BNB_4BIT_COMPUTE_DTYPE)\n",
    "\n",
    "bnb_config = BitsAndBytesConfig(\n",
    "    load_in_4bit=True,\n",
    "    bnb_4bit_quant_type=\"nf4\",\n",
    "    bnb_4bit_compute_dtype=compute_dtype,\n",
    "    bnb_4bit_use_double_quant=USE_NESTED_QUANT,\n",
    ")\n",
    "\n",
    "device_map = {\"\": 0}\n",
    "\n",
    "model = AutoModelForCausalLM.from_pretrained(\n",
    "        MODEL,\n",
    "        load_in_8bit=load_in_8bit,\n",
    "        quantization_config=bnb_config,\n",
    "        device_map=device_map,\n",
    "        use_cache=False,  # We will be using gradient checkpointing\n",
    "        trust_remote_code=True,\n",
    "        use_flash_attention_2=True,\n",
    ")\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "bO9e2FV8D8ZF"
   },
   "source": [
    "当使用量化模型进行训练时，你需要调用 `prepare_model_for_kbit_training()` 函数来预处理量化模型以进行训练。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "Qb_eB4xzEDBk"
   },
   "outputs": [],
   "source": [
    "model = prepare_model_for_kbit_training(model)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "lmnLjPZpDVtg"
   },
   "source": [
    "现在量化模型已经准备好了，我们可以设置一个 LoRA 配置。LoRA 通过大幅减少可训练参数的数量，使得微调更加高效。\n",
    "\n",
    "要使用 LoRA 技术训练模型，我们需要将基础模型包装为 `PeftModel`。这涉及到使用 `LoraConfig` 定义 LoRA 配置，并使用 `get_peft_model()` 和 `LoraConfig` 包装原始模型。\n",
    "\n",
    "要了解更多关于 LoRA 及其参数的信息，请参考 [PEFT 文档](https://huggingface.co/docs/peft/main/en/conceptual_guides/lora)。\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "_pAUU2FR2Gey",
    "outputId": "63328c2b-e693-49b1-ce0a-3ca8722f852a"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "trainable params: 5,554,176 || all params: 1,142,761,472 || trainable%: 0.4860310866343243\n"
     ]
    }
   ],
   "source": [
    "# Set up lora\n",
    "peft_config = LoraConfig(\n",
    "    lora_alpha=LORA_ALPHA,\n",
    "    lora_dropout=LORA_DROPOUT,\n",
    "    r=LORA_R,\n",
    "    bias=\"none\",\n",
    "    task_type=\"CAUSAL_LM\",\n",
    "    target_modules=LORA_TARGET_MODULES.split(\",\"),\n",
    ")\n",
    "\n",
    "model = get_peft_model(model, peft_config)\n",
    "model.print_trainable_parameters()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "tHe7AElXzXVV"
   },
   "source": [
    "可以看到，通过应用 LoRA 技术，我们现在只需要训练不到 1% 的参数。"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "T_CqVydc40IM"
   },
   "source": [
    "## 训练模型"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "Q_iN2khjrbD3"
   },
   "source": [
    "现在我们已经准备好了数据，并且优化了模型，我们可以将所有东西整合在一起开始训练。\n",
    "\n",
    "要实例化一个 `Trainer`，你需要定义训练配置。最重要的是 `TrainingArguments`，这是一个包含所有用于配置训练的属性的类。\n",
    "\n",
    "这些与你可能运行的任何其他类型的模型训练相似，所以我们这里不会详细说明。\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "65QHS8l1tKQe"
   },
   "outputs": [],
   "source": [
    "train_data.start_iteration = 0\n",
    "\n",
    "\n",
    "training_args = TrainingArguments(\n",
    "    output_dir=f\"Your_HF_username/{OUTPUT_DIR}\",\n",
    "    dataloader_drop_last=True,\n",
    "    evaluation_strategy=\"steps\",\n",
    "    save_strategy=\"steps\",\n",
    "    max_steps=MAX_STEPS,\n",
    "    eval_steps=EVAL_FREQ,\n",
    "    save_steps=SAVE_FREQ,\n",
    "    logging_steps=LOG_FREQ,\n",
    "    per_device_train_batch_size=BATCH_SIZE,\n",
    "    per_device_eval_batch_size=BATCH_SIZE,\n",
    "    learning_rate=LR,\n",
    "    lr_scheduler_type=LR_SCHEDULER_TYPE,\n",
    "    warmup_steps=NUM_WARMUP_STEPS,\n",
    "    gradient_accumulation_steps=GR_ACC_STEPS,\n",
    "    gradient_checkpointing=True,\n",
    "    fp16=FP16,\n",
    "    bf16=BF16,\n",
    "    weight_decay=WEIGHT_DECAY,\n",
    "    push_to_hub=True,\n",
    "    include_tokens_per_second=True,\n",
    ")\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "kB_fLRex09ut"
   },
   "source": [
    "最后一步，实例化 `Trainer` 并调用 `train` 方法。   "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 1000
    },
    "id": "rS3nVwhUC69O",
    "outputId": "61a5bdb2-b7d0-4aed-8290-4bf20c2ccd38"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Training...\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "\n",
       "    <div>\n",
       "      \n",
       "      <progress value='2000' max='2000' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
       "      [2000/2000 4:16:10, Epoch 1/9223372036854775807]\n",
       "    </div>\n",
       "    <table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       " <tr style=\"text-align: left;\">\n",
       "      <th>Step</th>\n",
       "      <th>Training Loss</th>\n",
       "      <th>Validation Loss</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <td>100</td>\n",
       "      <td>5.524600</td>\n",
       "      <td>7.456872</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>200</td>\n",
       "      <td>5.617800</td>\n",
       "      <td>7.262190</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>300</td>\n",
       "      <td>5.129100</td>\n",
       "      <td>6.410039</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>400</td>\n",
       "      <td>5.052200</td>\n",
       "      <td>6.306774</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>500</td>\n",
       "      <td>5.202900</td>\n",
       "      <td>6.117062</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>600</td>\n",
       "      <td>4.654100</td>\n",
       "      <td>6.018349</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>700</td>\n",
       "      <td>5.100200</td>\n",
       "      <td>6.000355</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>800</td>\n",
       "      <td>5.049800</td>\n",
       "      <td>5.889457</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>900</td>\n",
       "      <td>4.541200</td>\n",
       "      <td>5.813823</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>1000</td>\n",
       "      <td>5.000700</td>\n",
       "      <td>5.834208</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>1100</td>\n",
       "      <td>5.026500</td>\n",
       "      <td>5.781939</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>1200</td>\n",
       "      <td>4.411800</td>\n",
       "      <td>5.720596</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>1300</td>\n",
       "      <td>4.782500</td>\n",
       "      <td>5.736376</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>1400</td>\n",
       "      <td>4.980200</td>\n",
       "      <td>5.712276</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>1500</td>\n",
       "      <td>4.368700</td>\n",
       "      <td>5.689637</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>1600</td>\n",
       "      <td>4.884700</td>\n",
       "      <td>5.675920</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>1700</td>\n",
       "      <td>4.914400</td>\n",
       "      <td>5.662421</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>1800</td>\n",
       "      <td>4.248700</td>\n",
       "      <td>5.660122</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>1900</td>\n",
       "      <td>4.798400</td>\n",
       "      <td>5.664026</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>2000</td>\n",
       "      <td>4.704200</td>\n",
       "      <td>5.655665</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table><p>"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/plain": [
       "TrainOutput(global_step=2000, training_loss=4.885598585128784, metrics={'train_runtime': 15380.3075, 'train_samples_per_second': 2.081, 'train_steps_per_second': 0.13, 'train_tokens_per_second': 4261.033, 'total_flos': 4.0317260660736e+17, 'train_loss': 4.885598585128784, 'epoch': 1.0})"
      ]
     },
     "execution_count": 19,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "trainer = Trainer(\n",
    "    model=model, args=training_args, train_dataset=train_dataset, eval_dataset=eval_dataset\n",
    ")\n",
    "\n",
    "print(\"Training...\")\n",
    "trainer.train()\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "aAERlCnt1PEW"
   },
   "source": [
    "最后，你可以将微调好的模型推送到你的 Hub 仓库中，并分享给你的团队。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "1h7_AUTTDwE1"
   },
   "outputs": [],
   "source": [
    "trainer.push_to_hub()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "KBVH7uFOM_UF"
   },
   "source": [
    "## 推理\n",
    "\n",
    "一旦模型被上传到 Hub，我们就可以使用它进行推理。为此，我们首先初始化原始的基础模型及其分词器。接下来，我们需要将微调后的权重与基础模型合并。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "jtL37piINBFe"
   },
   "outputs": [],
   "source": [
    "from peft import PeftModel\n",
    "import torch\n",
    "\n",
    "# load the original model first\n",
    "tokenizer = AutoTokenizer.from_pretrained(MODEL, trust_remote_code=True)\n",
    "base_model = AutoModelForCausalLM.from_pretrained(\n",
    "    MODEL,\n",
    "    quantization_config=None,\n",
    "    device_map=None,\n",
    "    trust_remote_code=True,\n",
    "    torch_dtype=torch.bfloat16,\n",
    ").cuda()\n",
    "\n",
    "# merge fine-tuned weights with the base model\n",
    "peft_model_id = f\"Your_HF_username/{OUTPUT_DIR}\"\n",
    "model = PeftModel.from_pretrained(base_model, peft_model_id)\n",
    "model.merge_and_unload()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "3USQ2suvDi9M"
   },
   "source": [
    "现在我们可以使用合并后的模型进行推理。为了方便起见，我们将定义一个 `get_code_completion` 函数 - 请随意尝试文本生成参数！\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "RoTGpNbjDeWI"
   },
   "outputs": [],
   "source": [
    "def get_code_completion(prefix, suffix):\n",
    "    text = prompt = f\"\"\"{prefix}{suffix}\"\"\"\n",
    "    model.eval()\n",
    "    outputs = model.generate(\n",
    "        input_ids=tokenizer(text, return_tensors=\"pt\").input_ids.cuda(),\n",
    "        max_new_tokens=128,\n",
    "        temperature=0.2,\n",
    "        top_k=50,\n",
    "        top_p=0.95,\n",
    "        do_sample=True,\n",
    "        repetition_penalty=1.0,\n",
    "    )\n",
    "    return tokenizer.batch_decode(outputs, skip_special_tokens=True)[0]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "0kMJiGDfDrBf"
   },
   "source": [
    "现在，为了获得代码补全，我们只需要调用 `get_code_complete` 函数，并将我们希望补全的前几行作为前缀传递，以及一个空字符串作为后缀。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "nXlco2_-YcvM",
    "outputId": "41c411ad-b7dc-4277-f975-c173888234bb"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "from peft import LoraConfig, TaskType, get_peft_model\n",
      "from transformers import AutoModelForCausalLM\n",
      "peft_config = LoraConfig(\n",
      "    task_type=TaskType.CAUSAL_LM,\n",
      "    r=8,\n",
      "    lora_alpha=32,\n",
      "    target_modules=[\"q_proj\", \"v_proj\"],\n",
      "    lora_dropout=0.1,\n",
      "    bias=\"none\",\n",
      "    modules_to_save=[\"q_proj\", \"v_proj\"],\n",
      "    inference_mode=False,\n",
      ")\n",
      "model = AutoModelForCausalLM.from_pretrained(\"gpt2\")\n",
      "model = get_peft_model(model, peft_config)\n",
      "model.print_trainable_parameters()\n"
     ]
    }
   ],
   "source": [
    "prefix = \"\"\"from peft import LoraConfig, TaskType, get_peft_model\n",
    "from transformers import AutoModelForCausalLM\n",
    "peft_config = LoraConfig(\n",
    "\"\"\"\n",
    "suffix =\"\"\"\"\"\"\n",
    "\n",
    "print(get_code_completion(prefix, suffix))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "Ql2563kGlnmu"
   },
   "source": [
    "作为刚刚在这个 notebook 中使用过 PEFT 库的人，你可以看到创建为 `LoraConfig` 函数的生成结果相当不错！\n",
    "\n",
    "如果你回到我们为推理实例化模型的单元格，并注释掉我们合并微调权重的行，你可以看到原始模型对于完全相同的前缀会生成什么内容："
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "29xxp1eHTgJ9",
    "outputId": "c6d597a2-01da-4d25-a32f-3a551212c5b4"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "from peft import LoraConfig, TaskType, get_peft_model\n",
      "from transformers import AutoModelForCausalLM\n",
      "peft_config = LoraConfig(\n",
      "    model_name_or_path=\"facebook/wav2vec2-base-960h\",\n",
      "    num_labels=1,\n",
      "    num_features=1,\n",
      "    num_hidden_layers=1,\n",
      "    num_attention_heads=1,\n",
      "    num_hidden_layers_per_attention_head=1,\n",
      "    num_attention_heads_per_hidden_layer=1,\n",
      "    hidden_size=1024,\n",
      "    hidden_dropout_prob=0.1,\n",
      "    hidden_act=\"gelu\",\n",
      "    hidden_act_dropout_prob=0.1,\n",
      "    hidden\n"
     ]
    }
   ],
   "source": [
    "prefix = \"\"\"from peft import LoraConfig, TaskType, get_peft_model\n",
    "from transformers import AutoModelForCausalLM\n",
    "peft_config = LoraConfig(\n",
    "\"\"\"\n",
    "suffix =\"\"\"\"\"\"\n",
    "\n",
    "print(get_code_completion(prefix, suffix))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "Pwy2ZC7U8Ema"
   },
   "source": [
    "尽管这是 Python 语法，但你可以看到原始模型并不理解 `LoraConfig` 应该做什么。\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "CATYE8pp2drQ"
   },
   "source": [
    "要了解这种高效参数微调与完全微调的比较，以及如何通过推理端点在 VS Code 中使用这样的模型作为你的编程助手(copilot)，或者在本地使用，请查看[\"个人编程助手(copilot)：训练你自己的编码助手\"博客](https://huggingface.co/blog/personal-copilot)。这个 notebook 补充了原始博客内容。\n"
   ]
  }
 ],
 "metadata": {
  "accelerator": "GPU",
  "colab": {
   "gpuType": "A100",
   "machine_shape": "hm",
   "provenance": []
  },
  "kernelspec": {
   "display_name": "Python 3",
   "name": "python3"
  },
  "language_info": {
   "name": "python"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 0
}
